<!DOCTYPE html>
<html>
<head>
  <script src="face-api.js"></script>
  <script src="commons.js"></script>
  <script src="js/randomCrop.js"></script>
  <script src="js/indexedDb.js"></script>
  <script src="js/synchronizeDb.js"></script>
  <script src="FileSaver.js"></script>
</head>
<body>
  <div id="container"></div>

  <script>
    // Shorthand for the TensorFlow.js instance exposed by face-api.js.
    // NOTE(review): assigned without const/let/var, i.e. as an implicit global.
    tf = faceapi.tf

    // Network being trained; kept on window like the other shared training state.
    window.net = new faceapi.FaceLandmark68TinyNet()
    // uri to weights file of last checkpoint (loaded into window.net by load())
    const modelCheckpoint = 'tmp/face_landmark_68_tiny_model_lr00001_19.weights'
    // First epoch number the training loop runs. The checkpoint name above ends in
    // _19 (startEpoch - 1), which matches the per-epoch file naming in onEpochDone.
    const startEpoch = 20

    const learningRate = 0.0001 // 0.001
    // Adam optimizer; arguments are (learningRate, beta1, beta2, epsilon) per the tf.train.adam API.
    window.optimizer = tf.train.adam(learningRate, 0.9, 0.999, 1e-8)

    // NOTE(review): not read anywhere in the code visible in this file.
    window.saveEveryNthSample = Infinity

    // When true, train() passes every image and its points through randomCrop.
    window.withRandomCrop = true

    // Number of training ids consumed per optimizer step.
    window.batchSize = 12
    //window.batchSize = 32

    // Per-epoch sum of batch losses, indexed by epoch number.
    window.lossValues = []

    // Milliseconds to wait between iterations; when falsy, train() awaits tf.nextFrame() instead.
    window.iterDelay = 0
    // Toggles the per-iteration log output of train().
    window.withLogging = true

    // Console logger: prefixes the message with the first 8 characters of
    // Date#toTimeString() (HH:MM:SS) in square brackets. A falsy message is
    // printed as an empty string; extra arguments are forwarded to console.log.
    const log = (str, ...args) => {
      const timestamp = new Date().toTimeString().slice(0, 8)
      const message = str || ''
      console.log(`[${timestamp}] ${message}`, ...args)
    }

    /**
     * Serializes a network's parameters and passes them to saveAs
     * (provided by the FileSaver.js script included in <head>).
     *
     * @param {object} net - Object exposing serializeParams().
     * @param {string} [filename='train_tmp'] - File name handed to saveAs.
     */
    function saveWeights(net, filename = 'train_tmp') {
      const serializedParams = net.serializeParams()
      const weightsBlob = new Blob([serializedParams])
      saveAs(weightsBlob, filename)
    }

    /**
     * Downloads a weights file and returns its bytes viewed as 32-bit floats.
     *
     * @param {string} uri - Location of the serialized weights file.
     * @returns {Promise<Float32Array>} The response body as a Float32Array.
     * @throws {Error} When the server responds with a non-OK HTTP status.
     */
    async function loadNetWeights(uri) {
      const response = await fetch(uri)

      // fetch() only rejects on network failure. Without this check the body of
      // a 404/500 error page would be reinterpreted as float32 weights.
      if (!response.ok) {
        throw new Error(`failed to fetch weights from ${uri}: ${response.status} ${response.statusText}`)
      }

      return new Float32Array(await response.arrayBuffer())
    }

    /**
     * Resolves (with undefined) once a setTimeout of the given duration fires.
     *
     * @param {number} ms - Timeout in milliseconds passed to setTimeout.
     * @returns {Promise<undefined>}
     */
    async function delay(ms) {
      await new Promise((resolve) => {
        setTimeout(resolve, ms)
      })
    }

    /**
     * Requests '/face_landmarks_train_ids' and returns the parsed JSON body.
     *
     * @returns {Promise<*>} Whatever the endpoint's JSON decodes to
     *   (train() treats it as an array of training ids).
     */
    async function fetchTrainDataIds() {
      const response = await fetch('/face_landmarks_train_ids')
      return response.json()
    }

    /**
     * Prepares shared training state: stores the fetched training ids on
     * window.trainIds, loads the checkpoint weights from modelCheckpoint into
     * window.net, then calls window.net.variable().
     *
     * @returns {Promise<void>}
     */
    async function load() {
      const fetchedIds = await fetchTrainDataIds()
      window.trainIds = fetchedIds

      const checkpointWeights = await loadNetWeights(modelCheckpoint)
      await window.net.load(checkpointWeights)

      window.net.variable()
    }

    /**
     * Saves the artifacts of a finished epoch: the current weights of
     * window.net and a small JSON file with the epoch's accumulated loss.
     *
     * @param {number} epoch - Epoch number; used as index into window.lossValues
     *   and as suffix of both file names.
     * @returns {Promise<void>}
     */
    async function onEpochDone(epoch) {
      const fileStem = `face_landmark_68_tiny_model_lr00001_${epoch}`
      saveWeights(window.net, `${fileStem}.weights`)

      const loss = window.lossValues[epoch]
      // NOTE(review): loss is a sum of per-batch mean losses (see train()), yet it
      // is divided by the number of training ids rather than the number of batches.
      const avgLoss = loss / window.trainIds.length
      const summary = JSON.stringify({ loss, avgLoss })
      saveAs(new Blob([summary]), `${fileStem}.json`)
    }

    /**
     * Training entry point. Calls load() once, then loops over epochs starting at
     * startEpoch with no upper bound (epoch < Infinity), so the returned promise
     * does not resolve unless an error is thrown. Each epoch walks the shuffled
     * training ids in slices of window.batchSize and performs one
     * optimizer.minimize step per slice.
     *
     * window.trainIds, window.batchSize, window.withRandomCrop, window.net,
     * window.withLogging, window.logsilly and window.iterDelay are re-read on
     * every iteration. getPts, getJpgs and randomCrop are not defined in this
     * file; presumably they come from the js/*.js scripts included in <head>.
     */
    async function train() {
      await load()

      for (let epoch = startEpoch; epoch < Infinity; epoch++) {

        if (epoch !== startEpoch) {
          // Hack: the per-batch loss values are added to window.lossValues
          // asynchronously (see loss.data() below), so the previous epoch's
          // total may still be incomplete here. Saving is deferred by 10 s.
          // NOTE(review): a fixed delay does not guarantee all losses have arrived.
          setTimeout(() => onEpochDone(epoch - 1), 10000)
        }
        // Reset the accumulator for this epoch before any batch loss is added.
        window.lossValues[epoch] = 0

        // NOTE(review): presumably returns a shuffled copy; verify it does not reorder window.trainIds in place.
        const shuffledInputs = faceapi.shuffleArray(window.trainIds)

        for (let dataIdx = 0; dataIdx < shuffledInputs.length; dataIdx += window.batchSize) {
          // The ts* variables first hold a start timestamp and are then
          // overwritten with the elapsed milliseconds.
          const tsIter = Date.now()

          // The last slice of an epoch may hold fewer than window.batchSize ids.
          const ids = shuffledInputs.slice(dataIdx, dataIdx + window.batchSize)

          // fetch ground truth landmark points and jpg blobs for this batch
          let tsFetch = Date.now()
          let tsFetchPts = Date.now()

          const allPts = (await getPts(ids)).map(res => res.pts)

          tsFetchPts = Date.now() - tsFetchPts

          let tsFetchJpgs = Date.now()

          const allJpgBlobs = (await getJpgs(ids)).map(res => res.blob)

          tsFetchJpgs = Date.now() - tsFetchJpgs

          let tsBufferToImage = Date.now()

          // Decode all blobs of the batch concurrently.
          const allJpgs = await Promise.all(
            allJpgBlobs.map(blob => faceapi.bufferToImage(blob))
          )

          tsBufferToImage = Date.now() - tsBufferToImage

          // Total time for points + jpgs + decoding.
          tsFetch = Date.now() - tsFetch

          // Per-batch inputs (112 sized square images) and their target landmarks.
          const bImages = []
          const bPts = []

          for (let batchIdx = 0; batchIdx < ids.length; batchIdx += 1) {

            // Relies on getPts/getJpgs returning results in the same order as ids.
            const pts = allPts[batchIdx]
            const img = allJpgs[batchIdx]

            const { croppedImage, shiftedPts } = window.withRandomCrop
              ? randomCrop(img, pts)
              : { croppedImage: img, shiftedPts: pts }

            const squaredImg = faceapi.imageToSquare(croppedImage, 112, true)

            const reshapedDimensions = faceapi.computeReshapedDimensions(croppedImage, 112)

            // Half the difference between the reshaped sides, as a fraction of
            // 112; applied only to the shorter axis. This matches an image that
            // is centered inside the square, which is presumably what the `true`
            // argument of imageToSquare requests (TODO confirm).
            const pad = Math.abs(reshapedDimensions.width - reshapedDimensions.height) /  (2 * 112)
            const padX = reshapedDimensions.width < reshapedDimensions.height ? pad : 0
            const padY = reshapedDimensions.height < reshapedDimensions.width ? pad : 0

            // Map each point from cropped-image pixels to coordinates relative
            // to the 112 square: normalize by the crop size, scale by the share
            // of the square the reshaped image covers, then add the padding offset.
            const groundTruthLandmarks = shiftedPts.map(pt => new faceapi.Point(
              padX + (pt.x / croppedImage.width) * (reshapedDimensions.width / 112),
              padY + (pt.y / croppedImage.height)* (reshapedDimensions.height / 112)
            ))

            bImages.push(squaredImg)
            bPts.push(groundTruthLandmarks)
          }

          const netInput = await faceapi.toNetInput(bImages)

          let tsBackward = Date.now()
          let tsForward = Date.now()

          // `optimizer` resolves to window.optimizer. The second argument (true)
          // makes minimize return the loss value (tf.js `returnCost`).
          const loss = optimizer.minimize(() => {
            const out = window.net.runNet(netInput)
            // Forward time ends here; the "backprop" timer restarts and therefore
            // also covers building the target tensor and computing the loss.
            tsForward = Date.now() - tsForward
            tsBackward = Date.now()
            // Flatten the batch into [x0, y0, x1, y1, ...] and shape it as one
            // row of 136 values per sample (assumes 68 points per sample).
            const landmarksTensor = tf.tensor2d(
              bPts
                .reduce((flat, arr) => flat.concat(arr))
                .map(pt => [pt.x, pt.y])
                .reduce((flat, arr) => flat.concat(arr)),
              [ids.length, 136]
            )

            const loss = tf.losses.meanSquaredError(
              landmarksTensor,
              out,
              tf.Reduction.MEAN
            )

            return loss
          }, true)
          tsBackward = Date.now() - tsBackward

          // start next iteration without waiting for loss data

          // Once the loss value is available: add it to the epoch total, log it
          // and dispose the loss tensor.
          // NOTE(review): this promise has no rejection handler.
          loss.data().then(data => {
            const lossValue = data[0]
            window.lossValues[epoch] += lossValue
            window.withLogging && log(`epoch ${epoch}, dataIdx ${dataIdx} - loss: ${lossValue}, ( ${window.lossValues[epoch]})`)
            loss.dispose()
          })

          window.withLogging && log(`epoch ${epoch}, dataIdx ${dataIdx} - forward: ${tsForward} ms, backprop: ${tsBackward} ms, iter: ${Date.now() - tsIter} ms`)
          // window.logsilly is not initialized in this file; the detailed fetch
          // timings are only logged if it has been set to a truthy value elsewhere.
          if (window.logsilly)  {
            log(`fetch: ${tsFetch} ms, pts: ${tsFetchPts} ms, jpgs: ${tsFetchJpgs} ms, bufferToImage: ${tsBufferToImage} ms`)
          }
          // Yield between iterations: either for window.iterDelay milliseconds
          // or, when that is falsy, until tf.nextFrame() resolves.
          if (window.iterDelay) {
            await delay(window.iterDelay)
          } else {
            await tf.nextFrame()
          }
        }
      }
    }

  </script>
</body>
</html>